# agnews_znkd.py
#
# Full ZNKD pipeline for AG-News:
# - Teacher: 10-qubit VQC (1024-dim TF-IDF → PCA → angle encoding)
# - ZNE-corrected energies per class
# - Tanh-stabilized soft targets
# - Student: 6-qubit QNN via regression

import numpy as np
from datasets import load_dataset
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

from qiskit_aer import AerSimulator
from qiskit.utils import QuantumInstance
from qiskit.circuit.library import EfficientSU2
from qiskit import QuantumCircuit

from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC, NeuralNetworkRegressor

from agnews_zne_utils import zne_expectation_zero

# Seed reused by every seeded step in this file: PCA, the train/test split,
# the Aer simulator and the transpiler.
SEED = 123
# Number of target classes handed to the teacher classifier.
N_CLASSES = 4
# Width of the teacher circuit (ansatz and feature map).
TEACHER_QUBITS = 10
# Width of the student ansatz.
STUDENT_QUBITS = 6
# Vocabulary cap (max_features) for the TF-IDF vectorizer.
TFIDF_DIM = 1024
# PCA output dimension; tied to the teacher width so each retained component
# drives one teacher qubit rotation.
PCA_DIM = TEACHER_QUBITS


def load_agnews_pca():
    """Load AG-News and return angle-ready PCA features, split 80/20.

    The dataset's own train and test portions are merged and re-split with a
    stratified, SEED-controlled 80/20 split. TF-IDF, PCA and the min/max
    rescaling are fitted on the training portion only and then applied to the
    held-out portion, so no test-set statistics leak into preprocessing.

    Returns:
        list: ``[X_train, X_test, y_train, y_test]``. Feature arrays have
        ``PCA_DIM`` columns scaled into ``[0, pi]`` (one global min/max taken
        from the training features); label arrays hold integer class ids.
    """
    ds = load_dataset("ag_news")

    # list(...) keeps the concatenation valid whether column access returns a
    # plain Python list or a lazy column object.
    texts = list(ds["train"]["text"]) + list(ds["test"]["text"])
    labels = list(ds["train"]["label"]) + list(ds["test"]["label"])
    y = np.array(labels, dtype=int)

    # Split the raw texts first; the partition depends only on SEED and the
    # labels, so it selects the same samples as splitting after preprocessing.
    texts_train, texts_test, y_train, y_test = train_test_split(
        texts, y, test_size=0.2, random_state=SEED, stratify=y
    )

    # dtype=float32 avoids materialising a float64 dense matrix first.
    tfidf = TfidfVectorizer(
        max_features=TFIDF_DIM, stop_words="english", dtype=np.float32
    )
    X_train = tfidf.fit_transform(texts_train).toarray()
    X_test = tfidf.transform(texts_test).toarray()

    pca = PCA(n_components=PCA_DIM, random_state=SEED)
    X_train = pca.fit_transform(X_train)
    X_test = pca.transform(X_test)

    # Single global min/max (not per-feature), taken from the training data.
    lo = X_train.min()
    span = X_train.max() - lo + 1e-12  # epsilon guards a constant matrix
    X_train = np.pi * (X_train - lo) / span
    # Held-out values can fall outside the training range; clip so every
    # rotation angle stays inside [0, pi].
    X_test = np.clip(np.pi * (X_test - lo) / span, 0.0, np.pi)

    return [X_train, X_test, y_train, y_test]


def make_feature_map(num_qubits):
    """Return a callable that angle-encodes a feature vector as a circuit.

    Args:
        num_qubits: Width of the circuits the returned callable builds.

    Returns:
        A function ``fm(x)`` that creates a fresh ``num_qubits``-wide
        ``QuantumCircuit`` and applies one RY rotation per qubit, using the
        leading ``num_qubits`` entries of ``x`` as angles. Entries beyond
        ``num_qubits`` are ignored; if ``x`` is shorter, the remaining qubits
        receive no rotation.
    """
    def fm(x):
        circuit = QuantumCircuit(num_qubits)
        angles = x[:num_qubits]
        for qubit, angle in enumerate(angles):
            # float() turns NumPy scalars into plain Python floats.
            circuit.ry(float(angle), qubit)
        return circuit

    return fm


def build_teacher_vqc():
    """Build the (untrained) teacher classifier.

    The teacher is a ``VQC`` over ``TEACHER_QUBITS`` qubits: one RY rotation
    per qubit as the feature map, followed by an ``EfficientSU2`` ansatz with
    two repetitions, optimised with COBYLA on a seeded Aer simulator.

    Notes:
        * ``VQC`` expects ``feature_map`` to be a parameterized
          ``QuantumCircuit``, so the encoding is built here from a
          ``ParameterVector`` instead of being passed as a Python callable.
        * ``VQC`` expects an optimizer object (or callable), not the string
          ``"COBYLA"``.
        * ``VQC`` has no ``num_classes`` argument; per the library
          documentation the class count is derived from the labels passed to
          ``fit``.
        * NOTE(review): ``quantum_instance`` is only accepted by the
          qiskit-machine-learning releases that still ship the
          QuantumInstance code path - confirm the pinned version.

    Returns:
        VQC: An unfitted classifier.
    """
    # Local imports keep this function self-contained; both names live in
    # packages the module already depends on.
    from qiskit.algorithms.optimizers import COBYLA
    from qiskit.circuit import ParameterVector

    # Angle encoding: feature i drives an RY rotation on qubit i.
    inputs = ParameterVector("x", TEACHER_QUBITS)
    feature_map = QuantumCircuit(TEACHER_QUBITS)
    for qubit, angle in enumerate(inputs):
        feature_map.ry(angle, qubit)

    ansatz = EfficientSU2(TEACHER_QUBITS, reps=2)

    backend = AerSimulator(seed_simulator=SEED)
    qinst = QuantumInstance(
        backend=backend,
        seed_simulator=SEED,
        seed_transpiler=SEED,
    )

    return VQC(
        feature_map=feature_map,
        ansatz=ansatz,
        optimizer=COBYLA(),
        quantum_instance=qinst,
    )


def build_student_regressor():
    """Build the (untrained) student regressor.

    The student is a ``SamplerQNN`` over an ``EfficientSU2`` circuit with
    ``STUDENT_QUBITS`` qubits and one repetition. The first
    ``STUDENT_QUBITS`` circuit parameters are bound to input features and the
    remainder are trainable weights. It is wrapped in a
    ``NeuralNetworkRegressor`` trained with squared-error loss and COBYLA.

    Notes:
        * ``NeuralNetworkRegressor`` has no ``quantum_instance`` argument; the
          execution backend belongs to the QNN, so a seeded Aer ``Sampler`` is
          passed to ``SamplerQNN`` instead.
        * The optimizer must be an optimizer object, not the string
          ``"COBYLA"``.
        * Without ``interpret``/``output_shape`` the QNN emits one value per
          bitstring (``2 ** STUDENT_QUBITS`` columns). Measured bitstrings are
          folded into ``N_CLASSES`` buckets so the output has one column per
          class, which is what ``student_acc`` indexes into.
        * NOTE(review): the modulo bucketing is one reasonable choice, not
          the only one. The outputs are probabilities in [0, 1] while the
          tanh targets may be negative - confirm the intended output head.

    Returns:
        NeuralNetworkRegressor: An unfitted regressor with ``N_CLASSES``
        outputs.
    """
    # Local imports keep this function self-contained; both names live in
    # packages the module already depends on.
    from qiskit.algorithms.optimizers import COBYLA
    from qiskit_aer.primitives import Sampler as AerSampler

    ansatz = EfficientSU2(STUDENT_QUBITS, reps=1)
    params = list(ansatz.parameters)

    sampler = AerSampler(
        run_options={"seed": SEED},
        transpile_options={"seed_transpiler": SEED},
    )

    def to_class(bitstring_value):
        # Fold the integer value of a measured bitstring into a class bucket.
        return bitstring_value % N_CLASSES

    qnn = SamplerQNN(
        sampler=sampler,
        circuit=ansatz,
        input_params=params[:STUDENT_QUBITS],
        weight_params=params[STUDENT_QUBITS:],
        sparse=False,
        interpret=to_class,
        output_shape=N_CLASSES,
    )

    return NeuralNetworkRegressor(
        neural_network=qnn,
        # "squared_error" is the current name of the L2 loss.
        loss="squared_error",
        optimizer=COBYLA(),
    )


def compute_zne_tanh_targets(vqc, X, base_eps=0.01, tau=1.0):
    """Compute tanh-squashed soft targets for every sample in ``X``.

    For each sample, ``zne_expectation_zero`` is called with the teacher and
    ``base_eps``; its result is divided by ``tau`` and passed through
    ``tanh``. Progress is printed every 200 samples.

    Args:
        vqc: Teacher model forwarded unchanged to ``zne_expectation_zero``.
        X: Sized iterable of feature vectors.
        base_eps: Forwarded to ``zne_expectation_zero``.
        tau: Divisor applied before the ``tanh``.

    Returns:
        np.ndarray: The per-sample results stacked along axis 0.
    """
    rows = []
    for count, sample in enumerate(X, start=1):
        if count % 200 == 0:
            print(f"[ZNE targets] {count}/{len(X)} samples")

        # presumably one extrapolated energy per class (per the file header)
        # - confirm against agnews_zne_utils
        raw = zne_expectation_zero(vqc, sample, base_eps=base_eps)
        rows.append(np.tanh(raw / tau))

    return np.stack(rows, axis=0)


def teacher_acc(vqc, X, y):
    """Return the fraction of samples whose predicted label equals ``y``.

    Args:
        vqc: Object exposing ``predict(X)``.
        X: Features forwarded to ``predict``.
        y: Reference labels.

    Returns:
        The mean of the element-wise equality between predictions and labels.

    Raises:
        ValueError: If predictions and labels still differ in shape after a
            trailing singleton axis has been dropped.
    """
    preds = np.asarray(vqc.predict(X))
    labels = np.asarray(y)

    # An (n, 1) array compared with an (n,) array broadcasts to (n, n) and
    # silently yields a meaningless "accuracy"; drop the singleton axis first.
    if preds.ndim == labels.ndim + 1 and preds.shape[-1] == 1:
        preds = preds[..., 0]
    elif labels.ndim == preds.ndim + 1 and labels.shape[-1] == 1:
        labels = labels[..., 0]

    if preds.shape != labels.shape:
        raise ValueError(
            f"prediction shape {preds.shape} does not match "
            f"label shape {labels.shape}"
        )

    return (preds == labels).mean()


def student_acc(reg, X, y):
    """Return the student's accuracy under a lowest-output-wins rule.

    Args:
        reg: Object exposing ``predict(X)`` that returns a 2-D score array
            (one row per sample).
        X: Features forwarded to ``predict``.
        y: Reference class indices.

    Returns:
        The fraction of rows whose smallest-valued column index equals the
        corresponding entry of ``y``.
    """
    scores = reg.predict(X)
    # The predicted class is the column holding the minimum value.
    picked = np.argmin(scores, axis=1)
    hits = picked == y
    return hits.mean()


def main():
    """Run the full pipeline: data, teacher, soft targets, student.

    Steps, in order: load and split the features, fit the teacher and print
    its test accuracy, compute the tanh-stabilised targets on the training
    set, fit the student on those targets, and print the student's test
    accuracy.
    """
    print("Loading AG-News + PCA …")
    X_train, X_test, y_train, y_test = load_agnews_pca()

    print("Training TEACHER VQC …")
    teacher = build_teacher_vqc()
    teacher.fit(X_train, y_train)

    print("\nTeacher accuracy:", teacher_acc(teacher, X_test, y_test))

    print("\nComputing ZNE-based tanh targets …")
    targets = compute_zne_tanh_targets(teacher, X_train, base_eps=0.01, tau=1.0)

    # The student network declares STUDENT_QUBITS input parameters while the
    # features carry PCA_DIM columns; feeding the full matrix fails the QNN's
    # input-shape validation. Keep the leading columns, mirroring the
    # x[:num_qubits] truncation used by make_feature_map.
    X_train_student = X_train[:, :STUDENT_QUBITS]
    X_test_student = X_test[:, :STUDENT_QUBITS]

    print("Training STUDENT QNN regressor …")
    student = build_student_regressor()
    student.fit(X_train_student, targets)

    print("\nDistilled Student Accuracy:", student_acc(student, X_test_student, y_test))


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
